{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "a9FLXSQwKy07"
   },
   "source": [
    "# Traduction (PyTorch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KBQHiJybKy0-"
   },
   "source": [
    "Installez les bibliothèques 🤗 *Datasets* et 🤗 *Transformers* pour exécuter ce *notebook*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mBCiD3lkKy0_"
   },
   "outputs": [],
   "source": [
    "!pip install datasets transformers[sentencepiece]\n",
    "!pip install accelerate\n",
    "# Pour exécuter l'entraînement sur TPU, vous devez décommenter la ligne suivante :\n",
    "# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n",
    "!apt install git-lfs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jPmvTXi1Ky1A"
   },
   "source": [
    "Vous aurez besoin de configurer git, adaptez votre email et votre nom dans la cellule suivante."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WRi1-zwbKy1B"
   },
   "outputs": [],
   "source": [
    "!git config --global user.email \"you@example.com\"\n",
    "!git config --global user.name \"Your Name\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6_uvLx_cKy1C"
   },
   "source": [
    "Vous devrez également être connecté au Hub d'Hugging Face. Exécutez ce qui suit et entrez vos informations d'identification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6A8j7U20Ky1D"
   },
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "f7jZ0kj2Ky1D"
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset, load_metric\n",
    "\n",
    "raw_datasets = load_dataset(\"kde4\", lang1=\"en\", lang2=\"fr\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JPzEbB9CKy1E"
   },
   "outputs": [],
   "source": [
    "raw_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Y4XIrw4MKy1G"
   },
   "outputs": [],
   "source": [
    "split_datasets = raw_datasets[\"train\"].train_test_split(train_size=0.9, seed=20)\n",
    "split_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pOlt43_VKy1H"
   },
   "outputs": [],
   "source": [
    "split_datasets[\"validation\"] = split_datasets.pop(\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jvfc4756Ky1H"
   },
   "outputs": [],
   "source": [
    "split_datasets[\"train\"][1][\"translation\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Oq8MNDKKKy1I"
   },
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "model_checkpoint = \"Helsinki-NLP/opus-mt-en-fr\"\n",
    "translator = pipeline(\"translation\", model=model_checkpoint)\n",
    "translator(\"Default to expanded threads\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G5J86Q66Ky1J"
   },
   "outputs": [],
   "source": [
    "split_datasets[\"train\"][172][\"translation\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "77Ra7T93Ky1J"
   },
   "outputs": [],
   "source": [
    "translator(\n",
    "    \"Unable to import %1 using the OFX importer plugin. This file is not the correct format.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uyAiddy7Ky1K"
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_checkpoint = \"Helsinki-NLP/opus-mt-en-fr\"\n",
    "# This is the PyTorch version of the notebook, so ask for PyTorch (\"pt\") tensors\n",
    "# rather than TensorFlow (\"tf\") ones.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "J1LKf1zlKy1K"
   },
   "outputs": [],
   "source": [
    "en_sentence = split_datasets[\"train\"][1][\"translation\"][\"en\"]\n",
    "fr_sentence = split_datasets[\"train\"][1][\"translation\"][\"fr\"]\n",
    "\n",
    "inputs = tokenizer(en_sentence)\n",
    "with tokenizer.as_target_tokenizer():\n",
    "    targets = tokenizer(fr_sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "56xilSMiKy1L"
   },
   "outputs": [],
   "source": [
    "wrong_targets = tokenizer(fr_sentence)\n",
    "print(tokenizer.convert_ids_to_tokens(wrong_targets[\"input_ids\"]))\n",
    "print(tokenizer.convert_ids_to_tokens(targets[\"input_ids\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xY6wDyYmKy1M"
   },
   "outputs": [],
   "source": [
    "max_input_length = 128\n",
    "max_target_length = 128\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    \"\"\"Tokenize a batch of English/French pairs.\n",
    "\n",
    "    Args:\n",
    "        examples: batch with a \"translation\" entry whose items expose the\n",
    "            \"en\" (source) and \"fr\" (target) texts.\n",
    "\n",
    "    Returns:\n",
    "        The tokenized English inputs, truncated to `max_input_length`, with a\n",
    "        \"labels\" entry holding the input ids of the French targets truncated\n",
    "        to `max_target_length`.\n",
    "    \"\"\"\n",
    "    pairs = examples[\"translation\"]\n",
    "    source_texts = [pair[\"en\"] for pair in pairs]\n",
    "    target_texts = [pair[\"fr\"] for pair in pairs]\n",
    "\n",
    "    encoded = tokenizer(source_texts, max_length=max_input_length, truncation=True)\n",
    "\n",
    "    # Switch the tokenizer to target mode while encoding the French side\n",
    "    with tokenizer.as_target_tokenizer():\n",
    "        target_encoding = tokenizer(\n",
    "            target_texts, max_length=max_target_length, truncation=True\n",
    "        )\n",
    "\n",
    "    encoded[\"labels\"] = target_encoding[\"input_ids\"]\n",
    "    return encoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fklE6dSFKy1N"
   },
   "outputs": [],
   "source": [
    "tokenized_datasets = split_datasets.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    remove_columns=split_datasets[\"train\"].column_names,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3OPavkGXKy1N"
   },
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSeq2SeqLM\n",
    "\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5hblEiLDKy1N"
   },
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForSeq2Seq\n",
    "\n",
    "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6UlKyTObKy1N"
   },
   "outputs": [],
   "source": [
    "batch = data_collator([tokenized_datasets[\"train\"][i] for i in range(1, 3)])\n",
    "batch.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_wrd_1laKy1O"
   },
   "outputs": [],
   "source": [
    "batch[\"labels\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QGTJBtLxKy1P"
   },
   "outputs": [],
   "source": [
    "batch[\"decoder_input_ids\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "g1Nu16LkKy1P"
   },
   "outputs": [],
   "source": [
    "for i in range(1, 3):\n",
    "    print(tokenized_datasets[\"train\"][i][\"labels\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "D-V_-SfwKy1P"
   },
   "outputs": [],
   "source": [
    "!pip install sacrebleu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "USx5whYfKy1Q"
   },
   "outputs": [],
   "source": [
    "from datasets import load_metric\n",
    "\n",
    "metric = load_metric(\"sacrebleu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rOwFYZxmKy1Q"
   },
   "outputs": [],
   "source": [
    "predictions = [\n",
    "    \"This plugin lets you translate web pages between several languages automatically.\"\n",
    "]\n",
    "references = [\n",
    "    [\n",
    "        \"This plugin allows you to automatically translate web pages between several languages.\"\n",
    "    ]\n",
    "]\n",
    "metric.compute(predictions=predictions, references=references)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uIWAWo5mKy1Q"
   },
   "outputs": [],
   "source": [
    "predictions = [\"This This This This\"]\n",
    "references = [\n",
    "    [\n",
    "        \"This plugin allows you to automatically translate web pages between several languages.\"\n",
    "    ]\n",
    "]\n",
    "metric.compute(predictions=predictions, references=references)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bLqFm4fFKy1R"
   },
   "outputs": [],
   "source": [
    "predictions = [\"This plugin\"]\n",
    "references = [\n",
    "    [\n",
    "        \"This plugin allows you to automatically translate web pages between several languages.\"\n",
    "    ]\n",
    "]\n",
    "metric.compute(predictions=predictions, references=references)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rx8u9pDsKy1R"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_metrics(eval_preds):\n",
    "    \"\"\"Decode predictions and labels, then score them with `metric`.\n",
    "\n",
    "    Args:\n",
    "        eval_preds: a `(predictions, labels)` pair; `predictions` may itself be\n",
    "            a tuple, in which case only its first element is used.\n",
    "\n",
    "    Returns:\n",
    "        A dict with a single \"bleu\" key holding the \"score\" field of the\n",
    "        metric result.\n",
    "    \"\"\"\n",
    "    generated, reference_ids = eval_preds\n",
    "    # The model may return more than the prediction logits\n",
    "    if isinstance(generated, tuple):\n",
    "        generated = generated[0]\n",
    "\n",
    "    pred_texts = tokenizer.batch_decode(generated, skip_special_tokens=True)\n",
    "\n",
    "    # -100 cannot be decoded, so swap it for the padding token id first\n",
    "    reference_ids = np.where(\n",
    "        reference_ids != -100, reference_ids, tokenizer.pad_token_id\n",
    "    )\n",
    "    label_texts = tokenizer.batch_decode(reference_ids, skip_special_tokens=True)\n",
    "\n",
    "    # Light post-processing: trim whitespace, one reference list per prediction\n",
    "    cleaned_preds = [text.strip() for text in pred_texts]\n",
    "    cleaned_refs = [[text.strip()] for text in label_texts]\n",
    "\n",
    "    bleu = metric.compute(predictions=cleaned_preds, references=cleaned_refs)\n",
    "    return {\"bleu\": bleu[\"score\"]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VJrt4wY3Ky1S"
   },
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6rvcbMenKy1S"
   },
   "outputs": [],
   "source": [
    "from transformers import Seq2SeqTrainingArguments\n",
    "\n",
    "# The first positional argument is the output directory; it is a plain string\n",
    "# (no placeholders), so no f-string prefix is needed.\n",
    "args = Seq2SeqTrainingArguments(\n",
    "    \"marian-finetuned-kde4-en-to-fr\",\n",
    "    evaluation_strategy=\"no\",\n",
    "    save_strategy=\"epoch\",\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=64,\n",
    "    weight_decay=0.01,\n",
    "    save_total_limit=3,\n",
    "    num_train_epochs=3,\n",
    "    predict_with_generate=True,\n",
    "    fp16=True,\n",
    "    push_to_hub=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "M7scgmEvKy1T"
   },
   "outputs": [],
   "source": [
    "from transformers import Seq2SeqTrainer\n",
    "\n",
    "trainer = Seq2SeqTrainer(\n",
    "    model,\n",
    "    args,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cgZHNFuCKy1T"
   },
   "outputs": [],
   "source": [
    "trainer.evaluate(max_length=max_target_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JTcbgpizKy1T"
   },
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "a5PukEUnKy1U"
   },
   "outputs": [],
   "source": [
    "trainer.evaluate(max_length=max_target_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FPcIyn99Ky1U"
   },
   "outputs": [],
   "source": [
    "trainer.push_to_hub(tags=\"translation\", commit_message=\"Training complete\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fnkezZjWKy1V"
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "tokenized_datasets.set_format(\"torch\")\n",
    "train_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"train\"],\n",
    "    shuffle=True,\n",
    "    collate_fn=data_collator,\n",
    "    batch_size=8,\n",
    ")\n",
    "eval_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"validation\"], collate_fn=data_collator, batch_size=8\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "N8HHRxGfKy1V"
   },
   "outputs": [],
   "source": [
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RWKO6klrKy1V"
   },
   "outputs": [],
   "source": [
    "# transformers.AdamW is deprecated; use the PyTorch implementation instead.\n",
    "from torch.optim import AdamW\n",
    "\n",
    "# eps and weight_decay are set explicitly to the former transformers.AdamW\n",
    "# defaults so the hyperparameters do not change with the switch.\n",
    "optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-6, weight_decay=0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PCtEyEauKy1W"
   },
   "outputs": [],
   "source": [
    "from accelerate import Accelerator\n",
    "\n",
    "accelerator = Accelerator()\n",
    "model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n",
    "    model, optimizer, train_dataloader, eval_dataloader\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PlMSEpINKy1W"
   },
   "outputs": [],
   "source": [
    "from transformers import get_scheduler\n",
    "\n",
    "num_train_epochs = 3\n",
    "num_update_steps_per_epoch = len(train_dataloader)\n",
    "num_training_steps = num_train_epochs * num_update_steps_per_epoch\n",
    "\n",
    "lr_scheduler = get_scheduler(\n",
    "    \"linear\",\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=num_training_steps,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BpcWbqnbKy1X"
   },
   "outputs": [],
   "source": [
    "from huggingface_hub import Repository, get_full_repo_name\n",
    "\n",
    "model_name = \"marian-finetuned-kde4-en-to-fr-accelerate\"\n",
    "repo_name = get_full_repo_name(model_name)\n",
    "repo_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "erxPV1tXKy1X"
   },
   "outputs": [],
   "source": [
    "output_dir = \"marian-finetuned-kde4-en-to-fr-accelerate\"\n",
    "repo = Repository(output_dir, clone_from=repo_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cSrda_kiKy1X"
   },
   "outputs": [],
   "source": [
    "def postprocess(predictions, labels):\n",
    "    \"\"\"Turn predicted and label token ids into stripped text.\n",
    "\n",
    "    Args:\n",
    "        predictions: tensor of generated token ids.\n",
    "        labels: tensor of label token ids, where -100 marks ignored positions.\n",
    "\n",
    "    Returns:\n",
    "        A `(decoded_preds, decoded_labels)` pair: a list of prediction strings\n",
    "        and a list of single-element lists of reference strings.\n",
    "    \"\"\"\n",
    "    pred_ids = predictions.cpu().numpy()\n",
    "    label_ids = labels.cpu().numpy()\n",
    "\n",
    "    pred_texts = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
    "\n",
    "    # -100 cannot be decoded, so swap it for the padding token id first\n",
    "    label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)\n",
    "    label_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
    "\n",
    "    # Light post-processing: trim whitespace, one reference list per prediction\n",
    "    cleaned_preds = [text.strip() for text in pred_texts]\n",
    "    cleaned_refs = [[text.strip()] for text in label_texts]\n",
    "    return cleaned_preds, cleaned_refs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hs-rK7J7Ky1Y"
   },
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "\n",
    "progress_bar = tqdm(range(num_training_steps))\n",
    "\n",
    "for epoch in range(num_train_epochs):\n",
    "    # Training\n",
    "    model.train()\n",
    "    for batch in train_dataloader:\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        accelerator.backward(loss)\n",
    "\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "        progress_bar.update(1)\n",
    "\n",
    "    # Evaluation\n",
    "    model.eval()\n",
    "    for batch in tqdm(eval_dataloader):\n",
    "        with torch.no_grad():\n",
    "            generated_tokens = accelerator.unwrap_model(model).generate(\n",
    "                batch[\"input_ids\"],\n",
    "                attention_mask=batch[\"attention_mask\"],\n",
    "                # Same length cap as the tokenized labels and trainer.evaluate()\n",
    "                max_length=max_target_length,\n",
    "            )\n",
    "        labels = batch[\"labels\"]\n",
    "\n",
    "        # Predictions and labels must be padded to a common length before gathering\n",
    "        generated_tokens = accelerator.pad_across_processes(\n",
    "            generated_tokens, dim=1, pad_index=tokenizer.pad_token_id\n",
    "        )\n",
    "        labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)\n",
    "\n",
    "        predictions_gathered = accelerator.gather(generated_tokens)\n",
    "        labels_gathered = accelerator.gather(labels)\n",
    "\n",
    "        decoded_preds, decoded_labels = postprocess(predictions_gathered, labels_gathered)\n",
    "        metric.add_batch(predictions=decoded_preds, references=decoded_labels)\n",
    "\n",
    "    results = metric.compute()\n",
    "    print(f\"epoch {epoch}, BLEU score: {results['score']:.2f}\")\n",
    "\n",
    "    # Save and upload\n",
    "    accelerator.wait_for_everyone()\n",
    "    unwrapped_model = accelerator.unwrap_model(model)\n",
    "    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)\n",
    "    # Only the main process writes the tokenizer and pushes to the Hub\n",
    "    if accelerator.is_main_process:\n",
    "        tokenizer.save_pretrained(output_dir)\n",
    "        repo.push_to_hub(\n",
    "            commit_message=f\"Training in progress epoch {epoch}\", blocking=False\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5q2sbBk6Ky1Y"
   },
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# Remplacer par votre propre checkpoint\n",
    "model_checkpoint = \"huggingface-course/marian-finetuned-kde4-en-to-fr\"\n",
    "translator = pipeline(\"translation\", model=model_checkpoint)\n",
    "translator(\"Default to expanded threads\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oO_8WGkfKy1Z"
   },
   "outputs": [],
   "source": [
    "translator(\n",
    "    \"Unable to import %1 using the OFX importer plugin. This file is not the correct format.\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
